{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Callbacks\n",
    "\n",
    "> Callbacks which work with a learner's data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CollectDataCallback(Callback):\n",
    "    \"Collect all batches, along with `pred` and `loss`, into `self.data`. Mainly for testing\"\n",
    "    def before_fit(self): self.data = L()\n",
    "    def after_batch(self): self.data.append(to_detach((self.xb,self.yb,self.pred,self.loss)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class CudaCallback(Callback):\n",
    "    \"Move data to CUDA device\"\n",
    "    def __init__(self, device=None): self.device = ifnone(device, default_device())\n",
    "    def before_batch(self): self.learn.xb,self.learn.yb = to_device(self.xb),to_device(self.yb)\n",
    "    def before_fit(self): self.model.to(self.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You don't normally need to use this Callback, because fastai's `DataLoader` will handle passing data to a device for you. However, if you already have a plain PyTorch DataLoader and can't change it for some reason, you can use this transform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(#4) [0,18.003013610839844,12.680473327636719,'00:00']\n"
     ]
    }
   ],
   "source": [
    "#cuda\n",
    "learn = synth_learner(cbs=CudaCallback)\n",
    "learn.model\n",
    "learn.fit(1)\n",
    "test_eq(next(learn.model.parameters()).device.type, 'cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@log_args(but_as=TfmdDL.__init__)\n",
    "@delegates()\n",
    "class WeightedDL(TfmdDL):\n",
    "    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        wgts = array([1.]*len(dataset) if wgts is None else wgts)\n",
    "        self.wgts = wgts/wgts.sum()\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.n==0: return []\n",
    "        if not self.shuffle: return super().get_idxs()\n",
    "        return list(np.random.choice(self.n, self.n, p=self.wgts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def weighted_dataloaders(self:Datasets, wgts, bs=64, **kwargs):\n",
    "    xtra_kwargs = [{}] * (self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=({'wgts':wgts}, *xtra_kwargs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 160\n",
    "dsets = Datasets(torch.arange(n).float())\n",
    "dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)\n",
    "learn = synth_learner(data=dls, cbs=CollectDataCallback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(#4) [0,nan,None,'00:01']\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAO+klEQVR4nO3da4ycZ32G8euuTSjnAN7QYDu1i8zBEgRSE0Jb2kA42BHCQkJqwjklsiwRRFu1jSMk6oovBUqFEAbLSl2gUCIEKbipIW1JC61oIBtKQkwwLAnEi6FxSkUl+BAM/36Y1+kwzO6M7bFn8vj6SaOd93mefef2zOy9s+8cnKpCkvTg90vTDiBJmgwLXZIaYaFLUiMsdElqhIUuSY1YOa0LXrVqVa1bt25aFy9JD0q33nrrfVU1N2xuaoW+bt065ufnp3XxkvSglOQ7S815yEWSGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1YmShJ9mb5N4kdywxnyTvSbKQ5PYkF0w+piRplHEeoX8A2LzM/BZgQ3faBrz/5GNJko7XyEKvqs8DP1hmyVbgQ9VzM3B2knMnFVCSNJ5JHENfDRzq217sxn5Bkm1J5pPMHzlyZAIXfWLufOrTJr7Pz970pInv80Tt2n4T63b8w2m9zBO+vJ2POanL3bX9pgfOL3e7PnD7DLm8Sd4f3vW7L53Yvh5wEpl37tz5wHW0c+fO47rYE/2+fmPdJsexj6d/8Om/sI/++8BS9/3B7xvH4o5/O+7vWcpy979JmkShZ8jY0P8Gqar2VNWmqto0Nzf0owgkSSdoEoW+CKzt214DHJ7AfiVJx2EShb4PeG33apeLgB9W1fcmsF9J0nEY+WmLST4KXAysSrII/CnwEICq2g3sBy4FFoAfA1ecqrCSpKWNLPSqunzEfAFvnFgiSdIJ8Z2iktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEaMVehJNic5mGQhyY4h849J8vdJbktyIMkVk48qSVrOyEJPsgLYBWwBNgKXJ9k4sOyNwNeq6nzgYuBdSc6acFZJ0jLGeYR+IbBQVXdV1f3AdcDWgTUFPCpJgEcCPwCOTjSpJGlZ4xT6auBQ3/ZiN9bvvcDTgMPAV4E3V9XPBneUZFuS+STzR44cOcHIkqRhxin0DBmrge2XAF8Bngg8E3hvkkf/wjdV7amqTVW1aW5u7jijSpKWM06hLwJr+7bX0Hsk3u8K4PrqWQDuBp46mYiSpHGMU+i3ABuSrO+e6LwM2Dew5h7gEoAkTwCeAtw1yaCSpOWtHLWgqo4muQq4EVgB7K2qA0m2d/O7gbcBH0jyVXqHaK6uqvtOYW5J0oCRhQ5QVfuB/QNju/vOHwZePNlokqTj4TtFJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY0Yq9CTbE5yMMlCkh1LrLk4yVeSHEjyucnGlCSNsnLUgiQrgF3Ai4BF4JYk+6rqa31rzgbeB2yuqnuSnHOK8kqSljDOI/QLgYWququq7geuA7YOrHklcH1V3QNQVfdONqYkaZRxCn01cKhve7Eb6/dk4LFJ/jXJrUleO6mAkqTxjDzkAmTIWA3Zz68DlwAPA/4jyc1V9Y2f21GyDdgGcN555x1/WknSksZ5hL4IrO3bXgMcHrLmM1X1o6q6D/g8cP7gjqpqT1VtqqpNc3NzJ5pZkjTEOIV+C7AhyfokZwGXAfsG1nwKeF6SlUkeDjwHuHOyUSVJyxl5yKWqjia5CrgRWAHsraoDSbZ387ur6s4knwFuB34GXFtVd5zK4JKknzfOMXSqaj+wf2Bs98D2O4F3Ti6aJOl4+E5RSWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqxFiFnmRzkoNJFpLsWGbds5P8NMkrJhdRkjSOkYWeZAWwC9gCbAQuT7JxiXVvB26cdEhJ0mjjPEK/EFioqruq6n7gOmDrkHVvAj4B3DvBfJKkMY1T6KuBQ33bi93YA5KsBl4O7F5uR0m2JZlPMn/kyJHjzSpJWsY4hZ4hYzWw/W7g6qr66XI7qqo9VbWpqjbNzc2NGVGSNI6VY6xZBNb2ba8BDg+s2QRclwRgFXBpkqNV9clJhJQkjTZOod8CbEiyHvgucBnwyv4FVbX+2PkkHwBusMwl6fQaWehVdTTJVfRevbIC2FtVB5Js7+aXPW4uSTo9xnmETlXtB/YPjA0t8qp6/cnHkiQdL98pKkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWrEWIWeZHOSg0kWkuwYMv+qJLd3py8kOX/yUSVJyxlZ6ElWALuALcBG4PIkGweW3Q38TlU9A3gbsGfSQSVJyxvnEfqFwEJV3VVV9wPXAVv7F1TVF6rqf7rNm4E1k40pSRplnEJfDRzq217sxpbyBuDTwyaSbEsyn2T+yJEj46eUJI00TqFnyFgNXZg8n16hXz1svqr2VNWmqto0Nzc3fkpJ0kgrx1izCKzt214DHB5clOQZwLXAlqr678nEkySNa5xH6LcAG5KsT3IWcBmwr39BkvOA64HXVNU3Jh9TkjTKyEfoVXU0yVXAjcAKYG9VHUiyvZvfDbwVeDzwviQAR6tq06mLLUkaNM4hF6pqP7B/YGx33/krgSsnG02SdDx8p6gkNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDVirEJPsjnJwSQLSXYMmU+S93Tztye5YPJRJUnLGVnoSVYAu4AtwEbg8iQbB5ZtATZ0p23A+yecU5I0wjiP0C8EFqrqrqq6H7gO2DqwZivwoeq5GTg7ybkTzipJWkaqavkFySuAzVV1Zbf9GuA5VXVV35obgD+vqn/vtj8LXF1V8wP72kbvETzAU4CDk/qHTNgq4L5phxjBjCdv1vOBGSelpYy/WlVzwyZWjvHNGTI2+FtgnDVU1R5gzxiXOVVJ5qtq07RzLMeMJ2/W84EZJ+VMyTjOIZdFYG3f9hrg8AmskSSdQuMU+i3AhiTrk5wFXAbsG1izD3ht92qXi4AfVtX3JpxVkrSMkYdcqupokquAG4EVwN6qOpBkeze/G9gPXAosAD8Grjh1kU+LmT8shBknYdbzgRkn5YzIOPJJUUnSg4PvFJWkRljoktSIM7rQk6xN8i9J7kxyIMmbu/HHJfmnJN/svj52BrKuSPKf3Wv+Zy5jkrOTfDzJ17vr87kzmPEPutv5jiQfTfLL086YZG+Se5Pc0Te2ZKYk13QfsXEwyUummPGd3W19e5K/S3L2rGXsm/ujJJVk1bQyLpUvyZu6DAeSvOOk81XVGXsCzgUu6M4/CvgGvY83eAewoxvfAbx9BrL+IfC3wA3d9kxlBD4IXNmdPws4e5YyAquBu4GHddsfA14/7YzAbwMXAHf0jQ3N1N03bwMeCqwHvgWsmFLGFwMru/Nvn8WM3fhaei/o+A6waloZl7gOnw/8M/DQbvuck8132u64D4YT8CngRfTewXpuN3YucHDKudYAnwVe0FfoM5MReHRXlhkYn6WMq4FDwOPovbrrhq6Upp4RWDfwgz40E3ANcE3fuhuB504j48Dcy4GPzGJG4OPA+cC3+wp9KhmH3M4fA144ZN0J5zujD7n0S7IOeBbwReAJ1b2Ovvt6zhSjAbwb+BPgZ31js5Tx14AjwF93h4WuTfKIWcpYVd8F/gK4B/gevfdK/OMsZeyzVKZjv5SOWezGpu33gE9352cmY5KXAd+tqtsGpmYl45OB5yX5YpLPJXl2N37C+Sx0IMkjgU8Av19V/zvtPP2SvBS4t6punXaWZayk9+fk+6vqWcCP6B0qmBndceit9P6EfSLwiCSvnm6q4zbWR2ycTkneAhwFPnJsaMiy054xycOBtwBvHTY9ZGwa1+NK4LHARcAfAx9LEk4i3xlf6EkeQq/MP1JV13fD/3Xs0yK7r/dOKx/wm8DLknyb3iddviDJh5mtjIvAYlV9sdv+OL2Cn6WMLwTurqojVfUT4HrgN2Ys4zFLZZqpj9hI8jrgpcCrqjs2wOxkfBK9X963dT87a4AvJ/kVZifjInB99XyJ3l/gq04m3xld6N1vw78C7qyqv+yb2ge8rjv/OnrH1qeiqq6pqjVVtY7exy7cVFWvZrYyfh84lOQp3dAlwNeYoYz0DrVclOTh3e1+CXAns5XxmKUy7QMuS/LQJOvp/f8DX5pCPpJsBq4GXlZVP+6bmomMVfXVqjqnqtZ1PzuL9F4A8f1ZyQh8kt7zYiR5Mr0XE9x3UvlOx5MVs3oCfovenzK3A1/pTpcCj6f3JOQ3u6+Pm3bWLu/F/P+TojOVEXgmMN9dl5+k96fkrGX8M+DrwB3A39B7FcFUMwIfpXdM/yf0SucNy2WidxjhW/SeON0yxYwL9I7zHvu52T1rGQfmv033pOg0Mi5xHZ4FfLi7P34ZeMHJ5vOt/5LUiDP6kIsktcRCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY34P4odUNHcruvHAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(1)\n",
    "t = concat(*learn.collect_data.data.itemgot(0,0))\n",
    "plt.hist(t);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@log_args(but_as=TfmdDL.__init__)\n",
    "@delegates()\n",
    "class PartialDL(TfmdDL):\n",
    "    \"Select randomly partial quantity of data at each epoch\"\n",
    "    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        self.partial_n = min(partial_n, self.n) if partial_n else None\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.partial_n is None: return super().get_idxs()\n",
    "        return list(np.random.choice(self.n, self.partial_n, replace=False))\n",
    "\n",
    "    def __len__(self):\n",
    "        if self.partial_n is None: return super().__len__()\n",
    "        return self.partial_n//self.bs + (0 if self.drop_last or self.partial_n%self.bs==0 else 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):\n",
    "    \"Create a partial dataloader `PartialDL` for the training set\"\n",
    "    xtra_kwargs = [{}] * (self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=({'partial_n':partial_n}, *xtra_kwargs), **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = dsets.partial_dataloaders(partial_n=32, bs=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(dls[0])==2\n",
    "for batch in dls[0]:\n",
    "    assert len(batch[0])==16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
